package org.net5ijy.security.web.filter;

import java.io.IOException;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.net5ijy.security.config.security.SecurityConfig;
import org.net5ijy.security.util.RequestUtil;
import org.net5ijy.security.util.exception.ValidateCodeException;
import org.net5ijy.security.validate.ValidateCodeValidator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.web.filter.OncePerRequestFilter;

/**
 * Filter that runs verification-code checks before a request reaches the
 * login processing endpoint.
 *
 * <p>For every {@code POST} request whose URI (as resolved by
 * {@link RequestUtil#getRequestUri(HttpServletRequest)}) is contained in
 * {@link #getUrls()}, the filter:
 * <ol>
 *   <li>reads every value of the request parameter
 *   {@link ValidateCodeValidator#VALIDATE_TYPE_PARAMETER_NAME};</li>
 *   <li>looks up the validator stored under
 *   {@code type + ValidateCodeValidator.VALIDATOR_KEY_SUFFIX};</li>
 *   <li>runs each validator in order.</li>
 * </ol>
 * All requested validators must pass. The first {@link ValidateCodeException}
 * (including "parameter missing" and "no such validator") is handed to the
 * configured {@link AuthenticationFailureHandler}, and the filter chain is not
 * continued. Requests that do not match are passed through unchanged.
 *
 * @author XGF
 * @date 2020/11/30 22:21
 */
public class ValidateCodeFilter extends OncePerRequestFilter {

  private static final Logger log = LoggerFactory
      .getLogger(ValidateCodeFilter.class);

  /** Validators keyed by {@code type + ValidateCodeValidator.VALIDATOR_KEY_SUFFIX}. */
  private final Map<String, ValidateCodeValidator> validateCodeValidators;

  /** Receives the {@link ValidateCodeException} when a check fails. */
  private final AuthenticationFailureHandler authenticationFailureHandler;

  /** URIs whose POST requests must pass verification-code validation. */
  private Set<String> urls;

  /**
   * Creates the filter. Initially it intercepts only
   * {@link SecurityConfig#LOGIN_PROCESSING_URL}.
   *
   * @param validateCodeValidators validators keyed by
   *     {@code type + ValidateCodeValidator.VALIDATOR_KEY_SUFFIX}
   * @param authenticationFailureHandler handler invoked with the
   *     {@link ValidateCodeException} when validation fails
   */
  public ValidateCodeFilter(
      Map<String, ValidateCodeValidator> validateCodeValidators,
      AuthenticationFailureHandler authenticationFailureHandler) {
    super();
    this.validateCodeValidators = validateCodeValidators;
    this.authenticationFailureHandler = authenticationFailureHandler;

    // Initialize the set of URLs to intercept; callers may replace it via setUrls().
    this.urls = new HashSet<>();
    this.urls.add(SecurityConfig.LOGIN_PROCESSING_URL);
  }

  /**
   * Validates matching {@code POST} requests and otherwise continues the chain.
   *
   * @param request current request
   * @param response current response
   * @param filterChain remaining filter chain, invoked only when validation is
   *     not required or has passed
   * @throws ServletException propagated from the chain, a validator or the
   *     failure handler
   * @throws IOException propagated from the chain, a validator or the failure
   *     handler
   */
  @Override
  protected void doFilterInternal(HttpServletRequest request,
      HttpServletResponse response, FilterChain filterChain)
      throws ServletException, IOException {

    String uri = RequestUtil.getRequestUri(request);

    if (urls.contains(uri) && "POST".equalsIgnoreCase(request.getMethod())) {
      log.debug("内置验证码验证器: {}", validateCodeValidators);
      try {
        validateAll(request, response);
      } catch (ValidateCodeException e) {
        // Short-circuit: report the failure and do not let the request reach
        // the authentication endpoint.
        authenticationFailureHandler.onAuthenticationFailure(
            request, response, e);
        return;
      }
    }
    filterChain.doFilter(request, response);
  }

  /**
   * Runs every validator requested by the type parameter, in request order.
   * It stops at the first failure.
   *
   * @throws ValidateCodeException if the type parameter is missing or empty,
   *     if no validator is registered for a requested type, or if a validator
   *     rejects the request
   */
  private void validateAll(HttpServletRequest request,
      HttpServletResponse response) throws ServletException, IOException {
    String[] types = request
        .getParameterValues(ValidateCodeValidator.VALIDATE_TYPE_PARAMETER_NAME);

    // getParameterValues returns null when the parameter is absent. Treat that
    // as a failed check (not as "nothing to validate"), otherwise a client could
    // bypass verification simply by omitting the parameter.
    if (types == null || types.length == 0) {
      throw new ValidateCodeException(
          "未指定验证码类型: " + ValidateCodeValidator.VALIDATE_TYPE_PARAMETER_NAME);
    }

    for (String type : types) {
      String key = type + ValidateCodeValidator.VALIDATOR_KEY_SUFFIX;
      ValidateCodeValidator validator = validateCodeValidators.get(key);
      // An unknown type is a failure, not something to skip.
      if (validator == null) {
        throw new ValidateCodeException("未找到验证器: " + key);
      }
      validator.valid(request, response);
      log.info("验证通过: {}", key);
    }
  }

  /**
   * Returns the live (mutable) set of intercepted URIs.
   *
   * @return the set of URIs whose POST requests are validated
   */
  public Set<String> getUrls() {
    return urls;
  }

  /**
   * Replaces the set of intercepted URIs. The set is stored as-is, not copied.
   *
   * @param urls URIs whose POST requests should be validated; must not be
   *     {@code null}, since {@code doFilterInternal} calls
   *     {@code urls.contains(...)}
   */
  public void setUrls(Set<String> urls) {
    this.urls = urls;
  }
}
